{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "206a93a2-e9b0-4ea2-a43f-696faa83ea03",
   "metadata": {},
   "source": [
    "# 👾 PixelCNN from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1af9e216-7e84-4f5b-a2db-26aca3bea464",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own PixelCNN on the fashion MNIST dataset from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c1e6bbc-6f3b-48ac-a4f3-fde6f739f0ca",
   "metadata": {},
   "source": [
    "The code has been adapted from the excellent [PixelCNN tutorial](https://keras.io/examples/generative/pixelcnn/) created by ADMoreau, available on the Keras website."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6acebfa8-4546-41fd-adaa-2307c65b1b8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import datasets, layers, models, optimizers, callbacks\n",
    "\n",
    "from notebooks.utils import display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8543166d-f4c7-43f8-a452-21ccbf2a0496",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "444d84de-2843-40d6-8e2e-93691a5393ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SIZE = 16\n",
    "PIXEL_LEVELS = 4\n",
    "N_FILTERS = 128\n",
    "RESIDUAL_BLOCKS = 5\n",
    "BATCH_SIZE = 128\n",
    "EPOCHS = 150"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d65dac68-d20b-4ed9-a136-eed57095ce4f",
   "metadata": {},
   "source": [
    "## 1. Prepare the data <a name=\"prepare\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ed0fc56-d1b0-4d42-b029-f4198f78e666",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "(x_train, _), (_, _) = datasets.fashion_mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b667e78c-8fa7-4e5b-a2c0-69e50166ef77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the data\n",
    "def preprocess(imgs_int):\n",
    "    imgs_int = np.expand_dims(imgs_int, -1)\n",
    "    imgs_int = tf.image.resize(imgs_int, (IMAGE_SIZE, IMAGE_SIZE)).numpy()\n",
    "    imgs_int = (imgs_int / (256 / PIXEL_LEVELS)).astype(int)\n",
    "    imgs = imgs_int.astype(\"float32\")\n",
    "    imgs = imgs / PIXEL_LEVELS\n",
    "    return imgs, imgs_int\n",
    "\n",
    "\n",
    "input_data, output_data = preprocess(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c2b304-8385-4931-8291-9b7cc462c95e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show some items of clothing from the training set\n",
    "display(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ccd5cb2-8c7b-4667-8adb-4902f3fa60cf",
   "metadata": {},
   "source": [
    "## 2. Build the PixelCNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "847050a5-e4e6-4134-9bfc-c690cb8cb44d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first layer is the PixelCNN layer. This layer simply\n",
    "# builds on the 2D convolutional layer, but includes masking.\n",
    "class MaskedConv2D(layers.Layer):\n",
    "    def __init__(self, mask_type, **kwargs):\n",
    "        super(MaskedConv2D, self).__init__()\n",
    "        self.mask_type = mask_type\n",
    "        self.conv = layers.Conv2D(**kwargs)\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        # Build the conv2d layer to initialize kernel variables\n",
    "        self.conv.build(input_shape)\n",
    "        # Use the initialized kernel to create the mask\n",
    "        kernel_shape = self.conv.kernel.get_shape()\n",
    "        self.mask = np.zeros(shape=kernel_shape)\n",
    "        self.mask[: kernel_shape[0] // 2, ...] = 1.0\n",
    "        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0\n",
    "        if self.mask_type == \"B\":\n",
    "            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0\n",
    "\n",
    "    def call(self, inputs):\n",
    "        self.conv.kernel.assign(self.conv.kernel * self.mask)\n",
    "        return self.conv(inputs)\n",
    "\n",
    "    def get_config(self):\n",
    "        cfg = super().get_config()\n",
    "        return cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a52f7795-790e-47b0-b724-80be3e3c3666",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(layers.Layer):\n",
    "    def __init__(self, filters, **kwargs):\n",
    "        super(ResidualBlock, self).__init__(**kwargs)\n",
    "        self.conv1 = layers.Conv2D(\n",
    "            filters=filters // 2, kernel_size=1, activation=\"relu\"\n",
    "        )\n",
    "        self.pixel_conv = MaskedConv2D(\n",
    "            mask_type=\"B\",\n",
    "            filters=filters // 2,\n",
    "            kernel_size=3,\n",
    "            activation=\"relu\",\n",
    "            padding=\"same\",\n",
    "        )\n",
    "        self.conv2 = layers.Conv2D(\n",
    "            filters=filters, kernel_size=1, activation=\"relu\"\n",
    "        )\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x = self.conv1(inputs)\n",
    "        x = self.pixel_conv(x)\n",
    "        x = self.conv2(x)\n",
    "        return layers.add([inputs, x])\n",
    "\n",
    "    def get_config(self):\n",
    "        cfg = super().get_config()\n",
    "        return cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19b4508f-84de-42a9-a77f-950fb493db13",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))\n",
    "x = MaskedConv2D(\n",
    "    mask_type=\"A\",\n",
    "    filters=N_FILTERS,\n",
    "    kernel_size=7,\n",
    "    activation=\"relu\",\n",
    "    padding=\"same\",\n",
    ")(inputs)\n",
    "\n",
    "for _ in range(RESIDUAL_BLOCKS):\n",
    "    x = ResidualBlock(filters=N_FILTERS)(x)\n",
    "\n",
    "for _ in range(2):\n",
    "    x = MaskedConv2D(\n",
    "        mask_type=\"B\",\n",
    "        filters=N_FILTERS,\n",
    "        kernel_size=1,\n",
    "        strides=1,\n",
    "        activation=\"relu\",\n",
    "        padding=\"valid\",\n",
    "    )(x)\n",
    "\n",
    "out = layers.Conv2D(\n",
    "    filters=PIXEL_LEVELS,\n",
    "    kernel_size=1,\n",
    "    strides=1,\n",
    "    activation=\"softmax\",\n",
    "    padding=\"valid\",\n",
    ")(x)\n",
    "\n",
    "pixel_cnn = models.Model(inputs, out)\n",
    "pixel_cnn.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "442b5ffa-67a3-4b15-a342-eb1eed5e87ac",
   "metadata": {},
   "source": [
    "## 3. Train the PixelCNN <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7204789a-2ad3-48bf-b7e8-00d4cab10d9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "adam = optimizers.Adam(learning_rate=0.0005)\n",
    "pixel_cnn.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09d327fc-aff8-40e6-b390-d1bff4c06ea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    def __init__(self, num_img):\n",
    "        self.num_img = num_img\n",
    "\n",
    "    def sample_from(self, probs, temperature):  # <2>\n",
    "        probs = probs ** (1 / temperature)\n",
    "        probs = probs / np.sum(probs)\n",
    "        return np.random.choice(len(probs), p=probs)\n",
    "\n",
    "    def generate(self, temperature):\n",
    "        generated_images = np.zeros(\n",
    "            shape=(self.num_img,) + (pixel_cnn.input_shape)[1:]\n",
    "        )\n",
    "        batch, rows, cols, channels = generated_images.shape\n",
    "\n",
    "        for row in range(rows):\n",
    "            for col in range(cols):\n",
    "                for channel in range(channels):\n",
    "                    probs = self.model.predict(generated_images, verbose=0)[\n",
    "                        :, row, col, :\n",
    "                    ]\n",
    "                    generated_images[:, row, col, channel] = [\n",
    "                        self.sample_from(x, temperature) for x in probs\n",
    "                    ]\n",
    "                    generated_images[:, row, col, channel] /= PIXEL_LEVELS\n",
    "\n",
    "        return generated_images\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        generated_images = self.generate(temperature=1.0)\n",
    "        display(\n",
    "            generated_images,\n",
    "            save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
    "        )\n",
    "\n",
    "\n",
    "img_generator_callback = ImageGenerator(num_img=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85231056-d4a4-4897-ab91-065325a18d93",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pixel_cnn.fit(\n",
    "    input_data,\n",
    "    output_data,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    callbacks=[tensorboard_callback, img_generator_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfb4fa72-dd2d-44c1-ad18-9c965060683e",
   "metadata": {},
   "source": [
    "## 4. Generate images <a name=\"generate\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bbd4643-be09-49ba-b7bc-a524a2f00806",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_images = img_generator_callback.generate(temperature=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52cadb4b-ae2c-42a9-92ac-68e2131380ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(generated_images)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
